# -*- coding: utf-8 -*-
# BSD 3-Clause License
#
# Copyright (c) 2017
# All rights reserved.
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
#   list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
#   contributors may be used to endorse or promote products derived from
#   this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
# ==========================================================================

#!/usr/bin/env python

import os
from pathlib import Path
from typing import Dict, List

import fire
import torch

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from transformers.utils.logging import get_logger


# Module-level logger, namespaced to this module via transformers' logging helpers.
logger = get_logger(__name__)


def remove_prefix(text: str, prefix: str):
    """Return ``text`` with a leading ``prefix`` stripped, or ``text`` unchanged if absent."""
    return text[len(prefix) :] if text.startswith(prefix) else text


def sanitize(sd):
    """Strip the pytorch-lightning ``model.`` prefix from every key of state dict ``sd``."""
    cleaned = {}
    for key, value in sd.items():
        if key.startswith("model."):
            key = key[len("model.") :]
        cleaned[key] = value
    return cleaned


def average_state_dicts(state_dicts: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
    """Element-wise average of several state dicts.

    Args:
        state_dicts: Non-empty list of state dicts. Every dict must contain at
            least the keys of the first one; keys are taken from the first dict.

    Returns:
        A new state dict where each tensor is the mean of the corresponding
        tensors across ``state_dicts``.

    Raises:
        ValueError: If ``state_dicts`` is empty (previously a bare IndexError).
        KeyError: If a later state dict is missing a key present in the first.
    """
    if not state_dicts:
        raise ValueError("cannot average an empty list of state dicts")
    new_sd = {}
    for k in state_dicts[0]:
        tensors = [sd[k] for sd in state_dicts]
        # sum()/len() avoids materializing a stacked tensor; 0 + Tensor is always
        # a Tensor, so no isinstance sanity check is needed.
        new_sd[k] = sum(tensors) / len(tensors)
    return new_sd


def convert_pl_to_hf(pl_ckpt_path: str, hf_src_model_dir: str, save_path: str) -> None:
    """Cleanup a pytorch-lightning .ckpt file or experiment dir and save a huggingface model with that state dict.
    Silently allows extra pl keys (like teacher.) Puts all ckpt models into CPU RAM at once!

    Args:
        pl_ckpt_path (:obj:`str`): Path to a .ckpt file saved by pytorch_lightning or dir containing ckpt files.
            If a directory is passed, all .ckpt files inside it will be averaged!
        hf_src_model_dir (:obj:`str`): Path to a directory containing a correctly shaped checkpoint
        save_path (:obj:`str`): Directory to save the new model

    Raises:
        NotADirectoryError: If ``pl_ckpt_path`` is neither a file nor a directory.
        FileNotFoundError: If a directory is passed but it contains no ``.ckpt`` files.
        RuntimeError: If the averaged state dict lacks keys the HF model requires.
    """
    hf_model = AutoModelForSeq2SeqLM.from_pretrained(hf_src_model_dir)
    if os.path.isfile(pl_ckpt_path):
        ckpt_files = [pl_ckpt_path]
    else:
        # Explicit exceptions instead of `assert`: asserts are stripped under `python -O`.
        if not os.path.isdir(pl_ckpt_path):
            raise NotADirectoryError(f"{pl_ckpt_path} is neither a checkpoint file nor a directory")
        ckpt_files = list(Path(pl_ckpt_path).glob("*.ckpt"))
        if not ckpt_files:
            raise FileNotFoundError(f"could not find any ckpt files inside the {pl_ckpt_path} directory")

    if len(ckpt_files) > 1:
        # Lazy %-formatting: the repr of ckpt_files is only built if INFO is enabled.
        logger.info("averaging the weights of %s", ckpt_files)

    # NOTE(security): torch.load unpickles arbitrary objects — only run on checkpoints you trust.
    state_dicts = [sanitize(torch.load(x, map_location="cpu")["state_dict"]) for x in ckpt_files]
    state_dict = average_state_dicts(state_dicts)

    # strict=False tolerates extra PL keys (e.g. teacher.*); missing keys are still fatal.
    missing, unexpected = hf_model.load_state_dict(state_dict, strict=False)
    if missing:
        raise RuntimeError(f"missing keys: {missing}")
    hf_model.save_pretrained(save_path)
    try:
        tok = AutoTokenizer.from_pretrained(hf_src_model_dir)
        tok.save_pretrained(save_path)
    except Exception:
        # Best-effort: the source dir may not ship a tokenizer — skip, but say so
        # instead of silently swallowing the failure.
        logger.warning("could not load/save a tokenizer from %s; skipping it", hf_src_model_dir)


if __name__ == "__main__":
    # Expose convert_pl_to_hf as a CLI: fire maps its parameters to command-line arguments.
    fire.Fire(convert_pl_to_hf)
